import os
import logging
from flask import Flask, request
from config import (
    LLM_GPU_DEVICE,
    LLM_MODEL_PATH,
    LLM_SERVER_HOST,
    LLM_SERVER_PORT
)
from models import LargeLanguageModel
# Pin the visible GPU before any CUDA context is created so the model loads
# on the configured device.
# NOTE(review): this runs *after* `from models import LargeLanguageModel`
# (L10). That is fine as long as importing `models` does not initialize CUDA
# (e.g. allocate tensors on GPU at import time) — confirm; otherwise move
# this assignment above the `models` import.
os.environ["CUDA_VISIBLE_DEVICES"] = LLM_GPU_DEVICE


# Flask application serving the health-check and generation endpoints below.
app = Flask(__name__)
# Loaded once at import time and shared across requests; loading is expensive,
# so the process must not be forked per-request.
prompt_gen_model = LargeLanguageModel(LLM_MODEL_PATH)


@app.route("/", methods=["GET"])
def index():
    """Health-check endpoint: confirm the service is up and responding."""
    status = {"code": "000", "msg": "success"}
    return status


@app.route("/llm/generate/", methods=["POST"])
def prompt_generate():
    """Generate a completion for the posted prompt context.

    Expects a JSON body with:
        context (str, required): the prompt text.
        max_length (int, optional): generation cap, default 2048.
        top_p (float, optional): nucleus-sampling threshold, default 0.35.
        temperature (float, optional): sampling temperature, default 0.8.

    Returns a JSON envelope {"code", "msg", "data"}; on a missing or
    non-JSON body, returns an error envelope with HTTP 400 instead of
    letting a KeyError surface as an HTML 500 page.
    """
    # silent=True yields None on a missing/invalid JSON body rather than
    # raising, so we can answer in the service's JSON error format.
    data = request.get_json(silent=True)
    if not data or "context" not in data:
        return {
            "code": "001",
            "msg": "invalid request: JSON body with 'context' is required"
        }, 400
    resp = prompt_gen_model.process(
        data["context"],
        data.get("max_length", 2048),
        data.get("top_p", 0.35),
        data.get("temperature", 0.8)
    )
    # Strip the echoed prompt from the model output. NOTE(review): this
    # removes the *first* occurrence anywhere in the text; if the intent is
    # prefix-only stripping, str.removeprefix would be stricter — confirm
    # against the model's echo behavior.
    resp = resp.replace(data["context"], "", 1)
    return {
        "code": "000",
        "msg": "success",
        "data": resp
    }


if __name__ == "__main__":
    # Development entry point: bind the Flask dev server to the configured
    # host/port. Production deployments should front this with a WSGI server.
    app.run(host=LLM_SERVER_HOST, port=LLM_SERVER_PORT, debug=False)
